{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Concise Implementation of Recurrent Neural Networks\n",
    ":label:`sec_rnn-concise`\n",
    "\n",
    "While :numref:`sec_rnn_scratch` was instructive to see how RNNs are implemented,\n",
    "this is not convenient or fast.\n",
    "This section will show how to implement the same language model more efficiently\n",
    "using functions provided by the high-level APIs\n",
    "of a deep learning framework.\n",
    "We begin as before by reading the time machine dataset.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/PlotUtils.java\n",
    "\n",
    "%load ../utils/Accumulator.java\n",
    "%load ../utils/Animator.java\n",
    "%load ../utils/Functions.java\n",
    "%load ../utils/StopWatch.java\n",
    "%load ../utils/Training.java\n",
    "%load ../utils/timemachine/Vocab.java\n",
    "%load ../utils/timemachine/RNNModelScratch.java\n",
    "%load ../utils/timemachine/TimeMachine.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.training.dataset.Record;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating the Dataset in DJL"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In DJL, the ideal and concise way of dealing with datasets is to use the built-in datasets, which can easily wrap existing `NDArray`s, or to create your own dataset extending from the `RandomAccessDataset` class. For this section, we will implement our own. For more information on creating your own dataset in DJL, see: https://djl.ai/docs/development/how_to_use_dataset.html\n",
    "\n",
    "Our implementation of the `TimeMachineDataset` will be a concise replacement of the `SeqDataLoader` class previously created. Using a dataset in the DJL format will allow us to use its built-in functions, so we don't have to implement most of the functionality from scratch. We have to implement a builder, a `prepare` function which contains the process for saving the data to the `TimeMachineDataset` object, and a `get` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static class TimeMachineDataset extends RandomAccessDataset {\n",
    "\n",
    "    private Vocab vocab;\n",
    "    private NDArray data;\n",
    "    private NDArray labels;\n",
    "    private int numSteps;\n",
    "    private int maxTokens;\n",
    "    private int batchSize;\n",
    "    private NDManager manager;\n",
    "    private boolean prepared;\n",
    "\n",
    "    public TimeMachineDataset(Builder builder) {\n",
    "        super(builder);\n",
    "        this.numSteps = builder.numSteps;\n",
    "        this.maxTokens = builder.maxTokens;\n",
    "        this.batchSize = builder.getSampler().getBatchSize();\n",
    "        this.manager = builder.manager;\n",
    "        this.data = this.manager.create(new Shape(0, numSteps), DataType.INT32);\n",
    "        this.labels = this.manager.create(new Shape(0, numSteps), DataType.INT32);\n",
    "        this.prepared = false;\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public Record get(NDManager manager, long index) throws IOException {\n",
    "        NDArray X = data.get(new NDIndex(\"{}\", index));\n",
    "        NDArray Y = labels.get(new NDIndex(\"{}\", index));\n",
    "        return new Record(new NDList(X), new NDList(Y));\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    protected long availableSize() {\n",
    "        return data.getShape().get(0);\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public void prepare(Progress progress) throws IOException, TranslateException {\n",
    "        if (prepared) {\n",
    "            return;\n",
    "        }\n",
    "\n",
    "        Pair<List<Integer>, Vocab> corpusVocabPair = null;\n",
    "        try {\n",
    "            corpusVocabPair = TimeMachine.loadCorpusTimeMachine(maxTokens);\n",
    "        } catch (Exception e) {\n",
    "            e.printStackTrace(); // Exception can be from an unknown token type during the tokenize() function.\n",
    "        }\n",
    "        List<Integer> corpus = corpusVocabPair.getKey();\n",
    "        this.vocab = corpusVocabPair.getValue();\n",
    "\n",
    "        // Start with a random offset (inclusive of `numSteps - 1`)\n",
    "        // to partition a sequence\n",
    "        int offset = new Random().nextInt(numSteps);\n",
    "        int numTokens = ((int) ((corpus.size() - offset - 1) / batchSize)) * batchSize;\n",
    "        NDArray Xs =\n",
    "                manager.create(\n",
    "                        corpus.subList(offset, offset + numTokens).stream()\n",
    "                                .mapToInt(Integer::intValue)\n",
    "                                .toArray());\n",
    "        NDArray Ys =\n",
    "                manager.create(\n",
    "                        corpus.subList(offset + 1, offset + 1 + numTokens).stream()\n",
    "                                .mapToInt(Integer::intValue)\n",
    "                                .toArray());\n",
    "        Xs = Xs.reshape(new Shape(batchSize, -1));\n",
    "        Ys = Ys.reshape(new Shape(batchSize, -1));\n",
    "        int numBatches = (int) Xs.getShape().get(1) / numSteps;\n",
    "\n",
    "        NDList xNDList = new NDList();\n",
    "        NDList yNDList = new NDList();\n",
    "        for (int i = 0; i < numSteps * numBatches; i += numSteps) {\n",
    "            NDArray X = Xs.get(new NDIndex(\":, {}:{}\", i, i + numSteps));\n",
    "            NDArray Y = Ys.get(new NDIndex(\":, {}:{}\", i, i + numSteps));\n",
    "            xNDList.add(X);\n",
    "            yNDList.add(Y);\n",
    "        }\n",
    "        this.data = NDArrays.concat(xNDList);\n",
    "        xNDList.close();\n",
    "        this.labels = NDArrays.concat(yNDList);\n",
    "        yNDList.close();\n",
    "        this.prepared = true;\n",
    "    }\n",
    "\n",
    "    public Vocab getVocab() {\n",
    "        return this.vocab;\n",
    "    }\n",
    "\n",
    "    public static final class Builder extends BaseBuilder<Builder> {\n",
    "        int numSteps;\n",
    "        int maxTokens;\n",
    "        NDManager manager;\n",
    "\n",
    "        @Override\n",
    "        protected Builder self() { return this; }\n",
    "\n",
    "        public Builder setSteps(int steps) {\n",
    "            this.numSteps = steps;\n",
    "            return this;\n",
    "        }\n",
    "\n",
    "        public Builder setMaxTokens(int maxTokens) {\n",
    "            this.maxTokens = maxTokens;\n",
    "            return this;\n",
    "        }\n",
    "\n",
    "        public Builder setManager(NDManager manager) {\n",
    "            this.manager = manager;\n",
    "            return this;\n",
    "        }\n",
    "\n",
    "        public TimeMachineDataset build() throws IOException, TranslateException {\n",
    "            TimeMachineDataset dataset = new TimeMachineDataset(this);\n",
    "            return dataset;\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
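  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the partitioning arithmetic in `prepare` concrete, the following is a minimal plain-Java sketch (no DJL types; the toy corpus and all values are hypothetical) of how the offset, token count, and minibatch count are derived:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// A hypothetical toy corpus of 25 token ids, with batch size 2 and 5 steps.\n",
    "int[] toyCorpus = new int[25];\n",
    "for (int i = 0; i < toyCorpus.length; i++) toyCorpus[i] = i;\n",
    "int toyBatchSize = 2;\n",
    "int toyNumSteps = 5;\n",
    "int toyOffset = 3; // `prepare` draws this randomly from [0, numSteps)\n",
    "\n",
    "// Same arithmetic as `prepare`: drop the offset and the final label token,\n",
    "// then keep a multiple of the batch size.\n",
    "int toyNumTokens = ((toyCorpus.length - toyOffset - 1) / toyBatchSize) * toyBatchSize;\n",
    "int toyNumBatches = (toyNumTokens / toyBatchSize) / toyNumSteps;\n",
    "System.out.println(toyNumTokens + \" tokens, \" + toyNumBatches + \" minibatches of shape (2, 5)\");"
   ]
  },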
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Consequently, we will update the code from the previous section for the functions `predictCh8`, `trainCh8`, `trainEpochCh8`, and `gradClipping` to include the dataset logic and also to allow the functions to accept an `AbstractBlock` from DJL instead of just accepting `RNNModelScratch`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/** Generate new characters following the `prefix`. */\n",
    "public static String predictCh8(\n",
    "        String prefix,\n",
    "        int numPreds,\n",
    "        Object net,\n",
    "        Vocab vocab,\n",
    "        Device device,\n",
    "        NDManager manager) {\n",
    "\n",
    "    List<Integer> outputs = new ArrayList<>();\n",
    "    outputs.add(vocab.getIdx(\"\" + prefix.charAt(0)));\n",
    "    Functions.SimpleFunction<NDArray> getInput =\n",
    "            () ->\n",
    "                    manager.create(outputs.get(outputs.size() - 1))\n",
    "                            .toDevice(device, false)\n",
    "                            .reshape(new Shape(1, 1));\n",
    "\n",
    "    if (net instanceof RNNModelScratch) {\n",
    "        RNNModelScratch castedNet = (RNNModelScratch) net;\n",
    "        NDList state = castedNet.beginState(1, device);\n",
    "\n",
    "        for (char c : prefix.substring(1).toCharArray()) { // Warm-up period\n",
    "            state = (NDList) castedNet.forward(getInput.apply(), state).getValue();\n",
    "            outputs.add(vocab.getIdx(\"\" + c));\n",
    "        }\n",
    "\n",
    "        NDArray y;\n",
    "        for (int i = 0; i < numPreds; i++) {\n",
    "            Pair<NDArray, NDList> pair = castedNet.forward(getInput.apply(), state);\n",
    "            y = pair.getKey();\n",
    "            state = pair.getValue();\n",
    "\n",
    "            outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));\n",
    "        }\n",
    "    } else {\n",
    "        AbstractBlock castedNet = (AbstractBlock) net;\n",
    "        NDList state = null;\n",
    "        for (char c : prefix.substring(1).toCharArray()) { // Warm-up period\n",
    "            if (state == null) {\n",
    "                // Begin state\n",
    "                state =\n",
    "                        castedNet\n",
    "                                .forward(\n",
    "                                        new ParameterStore(manager, false),\n",
    "                                        new NDList(getInput.apply()),\n",
    "                                        false)\n",
    "                                .subNDList(1);\n",
    "            } else {\n",
    "                state =\n",
    "                        castedNet\n",
    "                                .forward(\n",
    "                                        new ParameterStore(manager, false),\n",
    "                                        new NDList(getInput.apply()).addAll(state),\n",
    "                                        false)\n",
    "                                .subNDList(1);\n",
    "            }\n",
    "            outputs.add(vocab.getIdx(\"\" + c));\n",
    "        }\n",
    "\n",
    "        NDArray y;\n",
    "        for (int i = 0; i < numPreds; i++) {\n",
    "            NDList pair =\n",
    "                    castedNet.forward(\n",
    "                            new ParameterStore(manager, false),\n",
    "                            new NDList(getInput.apply()).addAll(state),\n",
    "                            false);\n",
    "            y = pair.get(0);\n",
    "            state = pair.subNDList(1);\n",
    "\n",
    "            outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));\n",
    "        }\n",
    "    }\n",
    "\n",
    "    StringBuilder output = new StringBuilder();\n",
    "    for (int i : outputs) {\n",
    "        output.append(vocab.idxToToken.get(i));\n",
    "    }\n",
    "    return output.toString();\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/** Train a model. */\n",
    "public static void trainCh8(\n",
    "        Object net,\n",
    "        RandomAccessDataset dataset,\n",
    "        Vocab vocab,\n",
    "        int lr,\n",
    "        int numEpochs,\n",
    "        Device device,\n",
    "        boolean useRandomIter,\n",
    "        NDManager manager)\n",
    "        throws IOException, TranslateException {\n",
    "    SoftmaxCrossEntropyLoss loss = new SoftmaxCrossEntropyLoss();\n",
    "    Animator animator = new Animator();\n",
    "\n",
    "    Functions.voidTwoFunction<Integer, NDManager> updater;\n",
    "    if (net instanceof RNNModelScratch) {\n",
    "        RNNModelScratch castedNet = (RNNModelScratch) net;\n",
    "        updater =\n",
    "                (batchSize, subManager) ->\n",
    "                        Training.sgd(castedNet.params, lr, batchSize, subManager);\n",
    "    } else {\n",
    "        // The net is already initialized\n",
    "        AbstractBlock castedNet = (AbstractBlock) net;\n",
    "        Model model = Model.newInstance(\"model\");\n",
    "        model.setBlock(castedNet);\n",
    "\n",
    "        Tracker lrt = Tracker.fixed(lr);\n",
    "        Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();\n",
    "\n",
    "        DefaultTrainingConfig config =\n",
    "                new DefaultTrainingConfig(loss)\n",
    "                        .optOptimizer(sgd) // Optimizer (loss function)\n",
    "                        .optInitializer(\n",
    "                                new NormalInitializer(0.01f),\n",
    "                                Parameter.Type.WEIGHT) // setting the initializer\n",
    "                        .optDevices(Engine.getInstance().getDevices(1)) // setting the number of GPUs needed\n",
    "                        .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "                        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "        Trainer trainer = model.newTrainer(config);\n",
    "        updater = (batchSize, subManager) -> trainer.step();\n",
    "    }\n",
    "\n",
    "    Function<String, String> predict =\n",
    "            (prefix) -> predictCh8(prefix, 50, net, vocab, device, manager);\n",
    "    // Train and predict\n",
    "    double ppl = 0.0;\n",
    "    double speed = 0.0;\n",
    "    for (int epoch = 0; epoch < numEpochs; epoch++) {\n",
    "        Pair<Double, Double> pair =\n",
    "                trainEpochCh8(net, dataset, loss, updater, device, useRandomIter, manager);\n",
    "        ppl = pair.getKey();\n",
    "        speed = pair.getValue();\n",
    "        if ((epoch + 1) % 10 == 0) {\n",
    "            animator.add(epoch + 1, (float) ppl, \"ppl\");\n",
    "            animator.show();\n",
    "        }\n",
    "    }\n",
    "    System.out.format(\n",
    "            \"perplexity: %.1f, %.1f tokens/sec on %s%n\", ppl, speed, device.toString());\n",
    "    System.out.println(predict.apply(\"time traveller\"));\n",
    "    System.out.println(predict.apply(\"traveller\"));\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/** Train a model within one epoch. */\n",
    "public static Pair<Double, Double> trainEpochCh8(\n",
    "        Object net,\n",
    "        RandomAccessDataset dataset,\n",
    "        Loss loss,\n",
    "        Functions.voidTwoFunction<Integer, NDManager> updater,\n",
    "        Device device,\n",
    "        boolean useRandomIter,\n",
    "        NDManager manager)\n",
    "        throws IOException, TranslateException {\n",
    "    StopWatch watch = new StopWatch();\n",
    "    watch.start();\n",
    "    Accumulator metric = new Accumulator(2); // Sum of training loss, no. of tokens\n",
    "\n",
    "    try (NDManager childManager = manager.newSubManager()) {\n",
    "        NDList state = null;\n",
    "        for (Batch batch : dataset.getData(childManager)) {\n",
    "            NDArray X = batch.getData().head().toDevice(device, true);\n",
    "            NDArray Y = batch.getLabels().head().toDevice(device, true);\n",
    "            if (state == null || useRandomIter) {\n",
    "                // Initialize state when either it is the first iteration or\n",
    "                // using random sampling\n",
    "                if (net instanceof RNNModelScratch) {\n",
    "                    state =\n",
    "                            ((RNNModelScratch) net)\n",
    "                                    .beginState((int) X.getShape().getShape()[0], device);\n",
    "                }\n",
    "            } else {\n",
    "                for (NDArray s : state) {\n",
    "                    s.stopGradient();\n",
    "                }\n",
    "            }\n",
    "            if (state != null) {\n",
    "                state.attach(childManager);\n",
    "            }\n",
    "\n",
    "            NDArray y = Y.transpose().reshape(new Shape(-1));\n",
    "            X = X.toDevice(device, false);\n",
    "            y = y.toDevice(device, false);\n",
    "            try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n",
    "                NDArray yHat;\n",
    "                if (net instanceof RNNModelScratch) {\n",
    "                    Pair<NDArray, NDList> pairResult = ((RNNModelScratch) net).forward(X, state);\n",
    "                    yHat = pairResult.getKey();\n",
    "                    state = pairResult.getValue();\n",
    "                } else {\n",
    "                    NDList pairResult;\n",
    "                    if (state == null) {\n",
    "                        // Begin state\n",
    "                        pairResult =\n",
    "                                ((AbstractBlock) net)\n",
    "                                        .forward(\n",
    "                                                new ParameterStore(manager, false),\n",
    "                                                new NDList(X),\n",
    "                                                true);\n",
    "                    } else {\n",
    "                        pairResult =\n",
    "                                ((AbstractBlock) net)\n",
    "                                        .forward(\n",
    "                                                new ParameterStore(manager, false),\n",
    "                                                new NDList(X).addAll(state),\n",
    "                                                true);\n",
    "                    }\n",
    "                    yHat = pairResult.get(0);\n",
    "                    state = pairResult.subNDList(1);\n",
    "                }\n",
    "\n",
    "                NDArray l = loss.evaluate(new NDList(y), new NDList(yHat)).mean();\n",
    "                gc.backward(l);\n",
    "                metric.add(new float[] {l.getFloat() * y.size(), y.size()});\n",
    "            }\n",
    "            gradClipping(net, 1, childManager);\n",
    "            updater.apply(1, childManager); // Since the mean function has already been invoked\n",
    "        }\n",
    "    }\n",
    "    return new Pair<>(Math.exp(metric.get(0) / metric.get(1)), metric.get(1) / watch.stop());\n",
    "}"
   ]
  },
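  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The pair returned above packs perplexity and throughput: perplexity is simply the exponential of the average per-token loss accumulated in `metric`. As a small self-contained illustration in plain Java (the loss values below are hypothetical):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Accumulate (meanLoss * numTokens, numTokens) per batch as trainEpochCh8 does,\n",
    "// then exponentiate the average per-token loss to obtain the perplexity.\n",
    "double[] meanLossPerBatch = {1.2, 0.9, 1.0}; // hypothetical per-batch mean losses\n",
    "long tokensPerBatch = 70;\n",
    "double lossSum = 0;\n",
    "long tokenSum = 0;\n",
    "for (double l : meanLossPerBatch) {\n",
    "    lossSum += l * tokensPerBatch;\n",
    "    tokenSum += tokensPerBatch;\n",
    "}\n",
    "double toyPerplexity = Math.exp(lossSum / tokenSum);\n",
    "System.out.println(toyPerplexity); // about 2.81 for these values"
   ]
  },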
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "/** Clip the gradient. */\n",
    "public static void gradClipping(Object net, int theta, NDManager manager) {\n",
    "    double result = 0;\n",
    "    NDList params;\n",
    "    if (net instanceof RNNModelScratch) {\n",
    "        params = ((RNNModelScratch) net).params;\n",
    "    } else {\n",
    "        params = new NDList();\n",
    "        for (Pair<String, Parameter> pair : ((AbstractBlock) net).getParameters()) {\n",
    "            params.add(pair.getValue().getArray());\n",
    "        }\n",
    "    }\n",
    "    for (NDArray p : params) {\n",
    "        NDArray gradient = p.getGradient().stopGradient();\n",
    "        gradient.attach(manager);\n",
    "        result += gradient.pow(2).sum().getFloat();\n",
    "    }\n",
    "    double norm = Math.sqrt(result);\n",
    "    if (norm > theta) {\n",
    "        for (NDArray param : params) {\n",
    "            NDArray gradient = param.getGradient();\n",
    "            gradient.muli(theta / norm);\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
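  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The rule implemented above is the usual global-norm clipping $\\mathbf{g} \\leftarrow \\min\\left(1, \\frac{\\theta}{\\|\\mathbf{g}\\|}\\right) \\mathbf{g}$, applied over all parameters jointly. A minimal plain-Java sketch of the same rule on a raw `double[]` gradient (the values are hypothetical):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Clip a gradient vector to a maximum global L2 norm of theta,\n",
    "// mirroring gradClipping: scale by theta / norm only when norm > theta.\n",
    "double[] toyGrad = {3.0, 4.0}; // hypothetical gradient; its L2 norm is 5.0\n",
    "double toyTheta = 1.0;\n",
    "double sumSquares = 0;\n",
    "for (double g : toyGrad) sumSquares += g * g;\n",
    "double toyNorm = Math.sqrt(sumSquares);\n",
    "if (toyNorm > toyTheta) {\n",
    "    for (int i = 0; i < toyGrad.length; i++) {\n",
    "        toyGrad[i] *= toyTheta / toyNorm;\n",
    "    }\n",
    "}\n",
    "// toyGrad is now (0.6, 0.8) up to floating-point rounding; its norm is toyTheta\n",
    "System.out.println(toyGrad[0] + \", \" + toyGrad[1]);"
   ]
  },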
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we will leverage the dataset that we just created and assign the required parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int batchSize = 32;\n",
    "int numSteps = 35;\n",
    "\n",
    "TimeMachineDataset dataset = new TimeMachineDataset.Builder()\n",
    "        .setManager(manager).setMaxTokens(10000).setSampling(batchSize, false)\n",
    "        .setSteps(numSteps).build();\n",
    "dataset.prepare();\n",
    "Vocab vocab = dataset.getVocab();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining the Model\n",
    "\n",
    "High-level APIs provide implementations of recurrent neural networks.\n",
    "We construct the recurrent neural network layer `rnnLayer` with a single hidden layer and 256 hidden units.\n",
    "In fact, we have not even discussed yet what it means to have multiple layers; this will happen in :numref:`sec_deep_rnn`.\n",
    "For now, suffice it to say that multiple layers simply amount to the output of one layer of RNN being used as the input of the next layer of RNN.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int numHiddens = 256;\n",
    "RNN rnnLayer = RNN.builder().setNumLayers(1)\n",
    "        .setStateSize(numHiddens).optReturnState(true).optBatchFirst(false).build();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initializing the hidden state is straightforward.\n",
    "We invoke the member function `beginState` (in DJL, we don't have to call `beginState` before the first `forward` pass to specify the initial state, since DJL runs this logic the first time we run `forward`, but we create it here for demonstration purposes).\n",
    "This returns a list (`state`)\n",
    "that comprises\n",
    "an initial hidden state\n",
    "for each example in the minibatch,\n",
    "whose shape is\n",
    "(number of hidden layers, batch size, number of hidden units).\n",
    "For some models\n",
    "to be introduced later\n",
    "(e.g., long short-term memory),\n",
    "such a list also\n",
    "contains other information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static NDList beginState(int batchSize, int numLayers, int numHiddens) {\n",
    "    return new NDList(manager.zeros(new Shape(numLayers, batchSize, numHiddens)));\n",
    "}\n",
    "\n",
    "NDList state = beginState(batchSize, 1, numHiddens);\n",
    "System.out.println(state.size());\n",
    "System.out.println(state.get(0).getShape());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With a hidden state and an input,\n",
    "we can compute the output with\n",
    "the updated hidden state.\n",
    "It should be emphasized that\n",
    "the \"output\" (`Y`) of `rnnLayer`\n",
    "does *not* involve computation of output layers:\n",
    "it refers to\n",
    "the hidden state at *each* time step,\n",
    "and it can be used as the input\n",
    "to the subsequent output layer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Besides,\n",
    "the updated hidden state (`stateNew`) returned by `rnnLayer`\n",
    "refers to the hidden state\n",
    "at the *last* time step of the minibatch.\n",
    "It can be used to initialize the\n",
    "hidden state for the next minibatch within an epoch\n",
    "in sequential partitioning.\n",
    "For multiple hidden layers,\n",
    "the hidden state of each layer will be stored\n",
    "in this variable (`stateNew`).\n",
    "For some models\n",
    "to be introduced later\n",
    "(e.g., long short-term memory),\n",
    "this variable also\n",
    "contains other information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDArray X = manager.randomUniform(0, 1, new Shape(numSteps, batchSize, vocab.length()));\n",
    "\n",
    "NDList input = new NDList(X, state.get(0));\n",
    "rnnLayer.initialize(manager, DataType.FLOAT32, input.getShapes());\n",
    "NDList forwardOutput = rnnLayer.forward(new ParameterStore(manager, false), input, false);\n",
    "NDArray Y = forwardOutput.get(0);\n",
    "NDArray stateNew = forwardOutput.get(1);\n",
    "\n",
    "System.out.println(Y.getShape());\n",
    "System.out.println(stateNew.getShape());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similar to :numref:`sec_rnn_scratch`,\n",
    "we define an `RNNModel` class\n",
    "for a complete RNN model.\n",
    "Note that `rnnLayer` only contains the hidden recurrent layers, so we need to create a separate output layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public class RNNModel extends AbstractBlock {\n",
    "\n",
    "    private RNN rnnLayer;\n",
    "    private Linear dense;\n",
    "    private int vocabSize;\n",
    "\n",
    "    public RNNModel(RNN rnnLayer, int vocabSize) {\n",
    "        this.rnnLayer = rnnLayer;\n",
    "        this.addChildBlock(\"rnn\", rnnLayer);\n",
    "        this.vocabSize = vocabSize;\n",
    "        this.dense = Linear.builder().setUnits(vocabSize).build();\n",
    "        this.addChildBlock(\"linear\", dense);\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {\n",
    "        NDArray X = inputs.get(0).transpose().oneHot(vocabSize);\n",
    "        inputs.set(0, X);\n",
    "        NDList result = rnnLayer.forward(parameterStore, inputs, training);\n",
    "        NDArray Y = result.get(0);\n",
    "        NDArray state = result.get(1);\n",
    "\n",
    "        int shapeLength = Y.getShape().dimension();\n",
    "        NDList output = dense.forward(parameterStore, new NDList(Y\n",
    "                .reshape(new Shape(-1, Y.getShape().get(shapeLength-1)))), training);\n",
    "        return new NDList(output.get(0), state);\n",
    "    }\n",
    "    \n",
    "    @Override\n",
    "    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {\n",
    "        Shape shape = rnnLayer.getOutputShapes(new Shape[]{inputShapes[0]})[0];\n",
    "        dense.initialize(manager, dataType, new Shape(vocabSize, shape.get(shape.dimension() - 1)));\n",
    "    }\n",
    "\n",
    "    /* We won't implement this since we won't be using it, but it's required as part of an AbstractBlock  */\n",
    "    @Override\n",
    "    public Shape[] getOutputShapes(Shape[] inputShapes) {\n",
    "        return new Shape[0];\n",
    "    }\n",
    "}"
   ]
  },
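  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In `forwardInternal` above, the input of shape (`batchSize`, `numSteps`) is first transposed and then one-hot encoded via `oneHot(vocabSize)`. The encoding itself can be sketched in plain Java (hypothetical tiny vocabulary):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// One-hot encode a sequence of token ids into vectors of length vocabSize,\n",
    "// as X.transpose().oneHot(vocabSize) does for every time step and batch element.\n",
    "int toyVocabSize = 4; // hypothetical tiny vocabulary\n",
    "int[] toyTokenIds = {0, 2, 3};\n",
    "float[][] encoded = new float[toyTokenIds.length][toyVocabSize];\n",
    "for (int t = 0; t < toyTokenIds.length; t++) {\n",
    "    encoded[t][toyTokenIds[t]] = 1f;\n",
    "}\n",
    "System.out.println(java.util.Arrays.deepToString(encoded));\n",
    "// [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]"
   ]
  },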
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training and Predicting\n",
    "\n",
    "Before training the model, let us make a prediction with a model that has random weights.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Device device = manager.getDevice();\n",
    "RNNModel net = new RNNModel(rnnLayer, vocab.length());\n",
    "net.initialize(manager, DataType.FLOAT32, X.getShape());\n",
    "predictCh8(\"time traveller\", 10, net, vocab, device, manager);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As is quite obvious, this model does not work at all (yet).\n",
    "Next, we call `trainCh8` with the same hyperparameters defined in :numref:`sec_rnn_scratch` and train our model with high-level APIs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int numEpochs = Integer.getInteger(\"MAX_EPOCH\", 500);\n",
    "\n",
    "int lr = 1;\n",
    "trainCh8((Object) net, dataset, vocab, lr, numEpochs, device, false, manager);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compared with the last section, this model achieves a lower perplexity within a shorter period of time, since more of the code is optimized by\n",
    "the high-level APIs of the deep learning framework.\n",
    "\n",
    "\n",
    "## Summary\n",
    "\n",
    "* High-level APIs of the deep learning framework provide an implementation of the RNN layer.\n",
    "* The RNN layer of high-level APIs returns an output and an updated hidden state, where the output does not involve output layer computation.\n",
    "* Using high-level APIs leads to faster RNN training than using its implementation from scratch.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Can you make the RNN model overfit using the high-level APIs?\n",
    "2. What happens if you increase the number of hidden layers in the RNN model? Can you make the model work?\n",
    "3. Implement the autoregressive model of :numref:`sec_sequence` using an RNN.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
